{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "UmEFwsNs1OES"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Generate text embeddings by using the EmbeddingGemma model from Hugging Face\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/data_preprocessing/huggingface_text_embeddings.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/data_preprocessing/huggingface_text_embeddings.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ],
      "metadata": {
        "id": "ZUSiAR62SgO8"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "Use text embeddings to represent text as numerical vectors. This process lets computers understand and process text data, which is essential for many natural language processing (NLP) tasks.\n",
        "\n",
        "The following NLP tasks use embeddings:\n",
        "\n",
        "* **Semantic search:** Find documents or passages that are relevant to a query when the query doesn't use the exact same words as the documents.\n",
        "* **Text classification:** Categorize text data into different classes, such as spam and not spam, or positive sentiment and negative sentiment.\n",
        "* **Machine translation:** Translate text from one language to another and preserve the meaning.\n",
        "* **Text summarization:** Create shorter summaries of text.\n",
        "\n",
        "This notebook uses Apache Beam's `MLTransform` to generate embeddings from text data.\n",
        "\n",
        "Using a small, efficient open model like EmbeddingGemma at the core of your pipeline makes the entire process self-contained: the embedding step requires no network calls to external services. Because it's an open model, it can be hosted entirely within Dataflow, which lets you securely process large-scale, private datasets. For more information about the model, see the [model card](https://huggingface.co/google/embeddinggemma-300m).\n",
        "\n",
        "Hugging Face's [`SentenceTransformers`](https://huggingface.co/sentence-transformers) framework is a Python library for generating sentence, text, and image embeddings.\n",
        "\n",
        "To generate text embeddings from Hugging Face models with `MLTransform`, use the `SentenceTransformerEmbeddings` module to specify the model configuration.\n"
      ],
      "metadata": {
        "id": "yvVIEhF01ZWq"
      }
    },
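    {
      "cell_type": "markdown",
      "source": [
        "To see why embeddings are useful for tasks like semantic search, you can compare two embedding vectors directly. The following cell is an illustrative sketch, not part of the main pipeline: it assumes the `sentence-transformers` package is installed and that you can access the model. It embeds two differently worded questions and scores them with cosine similarity by using `util.cos_sim`; semantically similar sentences produce a score closer to 1.0."
      ],
      "metadata": {
        "id": "cosine-sim-demo-md"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Illustrative sketch: score two sentences with cosine similarity.\n",
        "from sentence_transformers import SentenceTransformer, util\n",
        "\n",
        "model = SentenceTransformer('google/embeddinggemma-300m')\n",
        "embeddings = model.encode([\n",
        "    'How do I sign up for Medicare?',\n",
        "    'What is the process to enroll in Medicare?'])\n",
        "# Semantically similar sentences score closer to 1.0.\n",
        "print(util.cos_sim(embeddings[0], embeddings[1]))"
      ],
      "metadata": {
        "id": "cosine-sim-demo-code"
      },
      "execution_count": null,
      "outputs": []
    },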
    {
      "cell_type": "markdown",
      "source": [
        "## Install dependencies\n",
        "\n",
        "Install Apache Beam and the dependencies needed to work with Hugging Face embeddings. These dependencies include the `sentence-transformers` package, which is required to use the `SentenceTransformerEmbeddings` module.\n"
      ],
      "metadata": {
        "id": "jqYXaBJ821Zs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "! pip install 'apache_beam[interactive]>=2.53.0' --quiet\n",
        "! pip install sentence-transformers --quiet"
      ],
      "metadata": {
        "id": "shzCUrZI1XhF"
      },
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import tempfile\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.transforms.base import MLTransform\n",
        "from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings"
      ],
      "metadata": {
        "id": "jVxSi2jS3M3b"
      },
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Authenticate with Hugging Face\n",
        "\n",
        "To ensure that you can pull the correct model, authenticate with Hugging Face by following the prompts in the cell."
      ],
      "metadata": {
        "id": "kXDM8C7d3nPW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!hf auth login"
      ],
      "metadata": {
        "id": "jVxSi2jS3M3c"
      },
      "execution_count": 29,
      "outputs": []
    },
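    {
      "cell_type": "markdown",
      "source": [
        "Alternatively, if you store your token in an environment variable or a secret manager, you can authenticate programmatically instead of through the interactive prompt. The following sketch assumes you have a valid Hugging Face access token available in the `HF_TOKEN` environment variable; `login` comes from the `huggingface_hub` package, which is installed with `sentence-transformers`."
      ],
      "metadata": {
        "id": "hf-login-alt-md"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Optional: authenticate without the interactive prompt.\n",
        "# Assumes your access token is stored in the HF_TOKEN environment variable.\n",
        "import os\n",
        "from huggingface_hub import login\n",
        "\n",
        "hf_token = os.environ.get('HF_TOKEN')\n",
        "if hf_token:\n",
        "  login(token=hf_token)"
      ],
      "metadata": {
        "id": "hf-login-alt-code"
      },
      "execution_count": null,
      "outputs": []
    },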
    {
      "cell_type": "markdown",
      "source": [
        "## Process the data\n",
        "\n",
        "`MLTransform` is a `PTransform` that you can use for data preparation, including generating text embeddings.\n",
        "\n",
        "### Use MLTransform in write mode\n",
        "\n",
        "In `write` mode, `MLTransform` saves the transforms and their attributes to an artifact location. Then, when you run `MLTransform` in `read` mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy.\n",
        "\n",
        "For more information about using `MLTransform`, see [Preprocess data with MLTransform](https://beam.apache.org/documentation/ml/preprocess-data/) in the Apache Beam documentation."
      ],
      "metadata": {
        "id": "kXDM8C7d3nPV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get the data\n",
        "\n",
        "The following text inputs come from the Hugging Face blog [Getting Started With Embeddings](https://huggingface.co/blog/getting-started-with-embeddings).\n",
        "\n",
        "\n",
        "`MLTransform` operates on dictionaries of data. To generate embeddings for specific columns, provide the column names in the `columns` argument of the `SentenceTransformerEmbeddings` transform."
      ],
      "metadata": {
        "id": "Dbkmu3HP6Kql"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "content = [\n",
        "    {'x': 'How do I get a replacement Medicare card?'},\n",
        "    {'x': 'What is the monthly premium for Medicare Part B?'},\n",
        "    {'x': 'How do I terminate my Medicare Part B (medical insurance)?'},\n",
        "    {'x': 'How do I sign up for Medicare?'},\n",
        "    {'x': 'Can I sign up for Medicare Part B if I am working and have health insurance through an employer?'},\n",
        "    {'x': 'How do I sign up for Medicare Part B if I already have Part A?'},\n",
        "    {'x': 'What are Medicare late enrollment penalties?'},\n",
        "    {'x': 'What is Medicare and who can get it?'},\n",
        "    {'x': 'How can I get help with my Medicare Part A and Part B premiums?'},\n",
        "    {'x': 'What are the different parts of Medicare?'},\n",
        "    {'x': 'Will my Medicare premiums be higher because of my higher income?'},\n",
        "    {'x': 'What is TRICARE ?'},\n",
        "    {'x': \"Should I sign up for Medicare Part B if I have Veterans' Benefits?\"}\n",
        "]\n",
        "\n",
        "text_embedding_model_name = 'google/embeddinggemma-300m'\n",
        "\n",
        "\n",
        "# Helper function that truncates each embedding to its first ten\n",
        "# values so that the printed output stays readable.\n",
        "def truncate_embeddings(d):\n",
        "  for key in d:\n",
        "    d[key] = d[key][:10]\n",
        "  return d"
      ],
      "metadata": {
        "id": "LCTUs8F73iDg"
      },
      "execution_count": 30,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "### Generate text embeddings\n",
        "This example uses the model `google/embeddinggemma-300m` to generate text embeddings. For more information about the model, see [the model card](https://huggingface.co/google/embeddinggemma-300m)."
      ],
      "metadata": {
        "id": "SApMmlRLRv_e"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "artifact_location_t5 = tempfile.mkdtemp(prefix='huggingface_')\n",
        "embedding_transform = SentenceTransformerEmbeddings(\n",
        "        model_name=text_embedding_model_name, columns=['x'])\n",
        "\n",
        "with beam.Pipeline() as pipeline:\n",
        "  data_pcoll = (\n",
        "      pipeline\n",
        "      | \"CreateData\" >> beam.Create(content))\n",
        "  transformed_pcoll = (\n",
        "      data_pcoll\n",
        "      | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location_t5).with_transform(embedding_transform))\n",
        "\n",
        "  transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n",
        "\n",
        "  transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SF6izkN134sf",
        "outputId": "524d3506-d31f-4dee-9079-1ed6d7cadf1a"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'x': [-0.0317193828523159, -0.005265399813652039, -0.012499183416366577, 0.00018130357784684747, -0.005592408124357462, 0.06207558885216713, -0.01656288281083107, 0.0167048592120409, -0.01239298190921545, 0.03041897714138031]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.015295305289328098, 0.005405726842582226, -0.015631258487701416, 0.022797023877501488, -0.027843449264764786, 0.03968179598450661, -0.004387892782688141, 0.022909151390194893, 0.01015392318367958, 0.04723235219717026]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.03450256213545799, -0.002632762538269162, -0.022460950538516045, -0.011689935810863972, -0.027329981327056885, 0.07293087989091873, -0.03069353476166725, 0.05429817736148834, -0.01308195199817419, 0.017668722197413445]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.02869587577879429, -0.0002648509689606726, -0.007186499424278736, -0.0003750955802388489, 0.012458174489438534, 0.06721009314060211, -0.013404129073023796, 0.03204648941755295, -0.021021844819188118, 0.04968355968594551]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.03241290897130966, 0.006845517549663782, 0.02001815102994442, -0.0057969288900494576, 0.008191823959350586, 0.08160955458879471, -0.009215254336595535, 0.023534387350082397, -0.02034241147339344, 0.0357462577521801]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.04592451825737953, -0.0025395643897354603, -0.01178023498505354, 0.011568977497518063, -0.0029014083556830883, 0.06971456110477448, -0.021167151629924774, 0.015902182087302208, -0.015007994137704372, 0.026213033124804497]}\n",
            "Embedding shape: 10\n",
            "{'x': [0.005221465136855841, -0.002127869985997677, -0.002369001042097807, -0.019337018951773643, 0.023243796080350876, 0.05599674955010414, -0.022721167653799057, 0.024813007563352585, -0.010685156099498272, 0.03624529018998146]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.035339221358299255, 0.010706206783652306, -0.001701260800473392, -0.00862252525985241, 0.006445988081395626, 0.08198338001966476, -0.022678885608911514, 0.01434261817485094, -0.008092232048511505, 0.03345781937241554]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.030748076736927032, 0.009340512566268444, -0.013637945055961609, 0.011183148249983788, -0.013879665173590183, 0.046350326389074326, -0.024090109393000603, 0.02885228954255581, -0.01699884608387947, 0.01672385260462761]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.040792081505060196, -0.00872269831597805, -0.015838179737329483, -0.03141209855675697, -7.104632823029533e-05, 0.08301416039466858, -0.034691162407398224, 0.0026397297624498606, 0.009255227632820606, 0.05415954813361168]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.02156883291900158, 0.003969342447817326, -0.030446071177721024, 0.008231461979448795, -0.01271845493465662, 0.03793857619166374, -0.013524272479116917, -0.0385628417134285, -0.0058258213102817535, 0.03505263477563858]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.027544165030121803, -0.01773364469408989, -0.013286487199366093, -0.008328652940690517, -0.011047529056668282, 0.05237515643239021, -0.016948163509368896, 0.02806701697409153, -0.0018120920285582542, 0.027241172268986702]}\n",
            "Embedding shape: 10\n",
            "{'x': [-0.03464886546134949, -0.003521248232573271, -0.010239562019705772, -0.018618224188685417, 0.004094886127859354, 0.062059685587882996, -0.013881963677704334, -0.0008639032603241503, -0.029874088242650032, 0.033531222492456436]}\n",
            "Embedding shape: 10\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "You can pass additional arguments that are supported by `sentence-transformers` models, such as `convert_to_numpy=False`. These arguments are passed as a `dict` to the `SentenceTransformerEmbeddings` transform by using the `inference_args` parameter.\n",
        "\n",
        "When you pass `convert_to_numpy=False`, the output contains `torch.Tensor` objects instead of NumPy arrays."
      ],
      "metadata": {
        "id": "1MFom0PW_vRv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "artifact_location_t5_with_inference_args = tempfile.mkdtemp(prefix='huggingface_')\n",
        "\n",
        "embedding_transform = SentenceTransformerEmbeddings(\n",
        "        model_name=text_embedding_model_name, columns=['x'],\n",
        "        inference_args={'convert_to_numpy': False}\n",
        "        )\n",
        "\n",
        "with beam.Pipeline() as pipeline:\n",
        "  data_pcoll = (\n",
        "      pipeline\n",
        "      | \"CreateData\" >> beam.Create(content))\n",
        "  transformed_pcoll = (\n",
        "      data_pcoll\n",
        "      | \"MLTransform\" >> MLTransform(write_artifact_location=artifact_location_t5_with_inference_args).with_transform(embedding_transform))\n",
        "\n",
        "  # The outputs are in the PyTorch tensor type.\n",
        "  transformed_pcoll | 'LogOutput' >> beam.Map(lambda x: print(type(x['x'])))\n",
        "\n",
        "  transformed_pcoll | \"PrintEmbeddingShape\" >> beam.Map(lambda x: print(f\"Embedding shape: {len(x['x'])}\"))\n"
      ],
      "metadata": {
        "id": "xyezKuzY_uLD",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "96babb3b-c61b-40a2-96f7-572a2a46bd83"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n",
            "<class 'torch.Tensor'>\n",
            "Embedding shape: 768\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Use MLTransform in read mode\n",
        "\n",
        "In `read` mode, `MLTransform` uses the artifacts generated during `write` mode. In this case, the `SentenceTransformerEmbeddings` transform and its attributes are loaded from the saved artifacts. You don't need to specify the transform again in `read` mode.\n",
        "\n",
        "In this way, `MLTransform` provides consistent preprocessing steps for training and inference workloads."
      ],
      "metadata": {
        "id": "aPIQzCoF_EBj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "test_content = [\n",
        "    {\n",
        "        'x': 'This is a test sentence'\n",
        "    },\n",
        "    {\n",
        "        'x': 'The park is full of dogs'\n",
        "    },\n",
        "    {\n",
        "        'x': \"Should I sign up for Medicare Part B if I have Veterans' Benefits?\"\n",
        "    }\n",
        "]\n",
        "\n",
        "# Uses the saved EmbeddingGemma artifacts to generate text embeddings\n",
        "with beam.Pipeline() as pipeline:\n",
        "  data_pcoll = (\n",
        "      pipeline\n",
        "      | \"CreateData\" >> beam.Create(test_content))\n",
        "  transformed_pcoll = (\n",
        "      data_pcoll\n",
        "      | \"MLTransform\" >> MLTransform(read_artifact_location=artifact_location_t5))\n",
        "\n",
        "  transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RCqYeUd3_F3C",
        "outputId": "4c5b61d2-df39-4af3-8520-f529f243e3b1"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'x': [0.00036313451710157096, -0.03929319977760315, -0.03574873134493828, 0.05015222355723381, 0.04295048117637634, 0.04800170287489891, 0.006883862894028425, -0.02567591704428196, -0.048067063093185425, 0.036534328013658524]}\n",
            "{'x': [-0.053793832659721375, 0.006730600260198116, -0.025130020454525948, 0.04363932088017464, 0.03323192894458771, 0.008803879842162132, -0.015412433072924614, 0.008926985785365105, -0.061175212264060974, 0.04573329910635948]}\n",
            "{'x': [-0.03464885801076889, -0.003521254053339362, -0.010239563882350922, -0.018618224188685417, 0.004094892647117376, 0.062059689313173294, -0.013881963677704334, -0.000863900815602392, -0.029874078929424286, 0.03353121876716614]}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Next steps\n",
        "\n",
        "Now that you've generated embeddings, you can use `MLTransform` and sinks to ingest your data into a vector database. To explore this workflow, along with more advanced concepts, see the following notebooks:\n",
        "\n",
        "- [Vector Embedding Ingestion with Apache Beam and AlloyDB](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/alloydb_product_catalog_embeddings.ipynb)\n",
        "- [Embedding Ingestion and Vector Search with Apache Beam and BigQuery](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/bigquery_vector_ingestion_and_search.ipynb)\n",
        "- [Vector Embedding Ingestion with Apache Beam and CloudSQL Postgres](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/cloudsql_postgres_product_catalog_embeddings.ipynb#scrollTo=K6-p-DVrIFTY)"
      ],
      "metadata": {
        "id": "l31V3Q0Uo41z"
      }
    }
  ]
}
